{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "EixH_3T4RN9u"
   },
   "source": [
    "# Ungraded lab: Shapley Values\n",
    "------------------------\n",
    "\n",
    "Welcome, during this ungraded lab you are going to be working with SHAP (SHapley Additive exPlanations). This procedure is derived from game theory and aims to understand (or explain) the output of any machine learning model. In particular you will:\n",
    "\n",
    "\n",
    "1. Train a simple CNN on the fashion mnist dataset.\n",
    "2. Compute the Shapley values for examples of each class.\n",
    "3. Visualize these values and derive information from them.\n",
    "\n",
    "To learn more about Shapley Values visit the official [SHAP repo](https://github.com/slundberg/shap).\n",
    "\n",
    "Let's get started!\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "i7j3OLmgUWzB"
   },
   "source": [
    "## Imports\n",
    "\n",
    "Begin by installing the shap library:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "N7Fzf05Amwpx"
   },
   "outputs": [],
   "source": [
    "!pip install shap"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9SnFiuF_V0MY"
   },
   "source": [
    "Now import all necessary dependencies:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "K5BRI5W3mf5q"
   },
   "outputs": [],
   "source": [
    "import shap\n",
    "import numpy as np\n",
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0fZHtO_oWh3_"
   },
   "source": [
    "## Train a CNN model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "hIB4uc1qhgc9"
   },
   "source": [
    "For this lab you will use the [fashion MNIST](https://keras.io/api/datasets/fashion_mnist/) dataset. Load it and pre-process the data before feeding it into the model:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6ap5dqFQmsDC"
   },
   "outputs": [],
   "source": [
    "# Download the dataset\n",
    "(x_train, y_train), (x_test, y_test) = keras.datasets.fashion_mnist.load_data()\n",
    "\n",
    "# Reshape and normalize data\n",
    "x_train = x_train.reshape(60000, 28, 28, 1).astype(\"float32\") / 255\n",
    "x_test = x_test.reshape(10000, 28, 28, 1).astype(\"float32\") / 255"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CBak9dOqiI2V"
   },
   "source": [
    "For the CNN model you will use a simple architecture composed of a single  convolutional and maxpooling layers pair connected to a fully conected layer with 256 units and the output layer with 10 units since there are 10 categories.\n",
    "\n",
    "Define the model using Keras' [Functional API](https://keras.io/guides/functional_api/):\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "SH7gEwiksWDE"
   },
   "outputs": [],
   "source": [
    "# Define the model architecture using the functional API\n",
    "inputs = keras.Input(shape=(28, 28, 1))\n",
    "x = keras.layers.Conv2D(32, (3, 3), activation='relu')(inputs)\n",
    "x = keras.layers.MaxPooling2D((2, 2))(x)\n",
    "x = keras.layers.Flatten()(x)\n",
    "x = keras.layers.Dense(256, activation='relu')(x)\n",
    "outputs = keras.layers.Dense(10, activation='softmax')(x)\n",
    "\n",
    "# Create the model with the corresponding inputs and outputs\n",
    "model = keras.Model(inputs=inputs, outputs=outputs, name=\"CNN\")\n",
    "\n",
    "# Compile the model\n",
    "model.compile(\n",
    "      loss=tf.keras.losses.SparseCategoricalCrossentropy(),\n",
    "      optimizer=keras.optimizers.Adam(),\n",
    "      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]\n",
    "  )\n",
    "\n",
    "# Train it!\n",
    "model.fit(x_train, y_train, epochs=5, validation_data=(x_test, y_test))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rwSN-8STj_bh"
   },
   "source": [
    "Judging the accuracy metrics looks like the model is overfitting. However, it is achieving a >90% accuracy on the test set so its performance is adequate for the purposes of this lab."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jJv0LghcmX_P"
   },
   "source": [
    "# Explaining the outputs\n",
    "\n",
    "You know that the model is correctly classifying around 90% of the images in the test set. But how is it doing it? What pixels are being used to determine if an image belongs to a particular class?\n",
    "\n",
    "To answer these questions you can use SHAP values.\n",
    "\n",
    "Before doing so, check how each one of the categories looks like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "hSqcehYVxSZY"
   },
   "outputs": [],
   "source": [
    "# Name each one of the classes\n",
    "class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',\n",
    "               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']\n",
    "\n",
    "# Save an example for each category in a dict\n",
    "images_dict = dict()\n",
    "for i, l in enumerate(y_train):\n",
    "  if len(images_dict)==10:\n",
    "    break\n",
    "  if l not in images_dict.keys():\n",
    "    images_dict[l] = x_train[i].reshape((28, 28))\n",
    "\n",
    "# Function to plot images\n",
    "def plot_categories(images):\n",
    "  fig, axes = plt.subplots(1, 11, figsize=(16, 15))\n",
    "  axes = axes.flatten()\n",
    "\n",
    "  # Plot an empty canvas\n",
    "  ax = axes[0]\n",
    "  dummy_array = np.array([[[0, 0, 0, 0]]], dtype='uint8')\n",
    "  ax.set_title(\"reference\")\n",
    "  ax.set_axis_off()\n",
    "  ax.imshow(dummy_array, interpolation='nearest')\n",
    "\n",
    "  # Plot an image for every category\n",
    "  for k,v in images.items():\n",
    "    ax = axes[k+1]\n",
    "    ax.imshow(v, cmap=plt.cm.binary)\n",
    "    ax.set_title(f\"{class_names[k]}\")\n",
    "    ax.set_axis_off()\n",
    "\n",
    "  plt.tight_layout()\n",
    "  plt.show()\n",
    "\n",
    "\n",
    "# Use the function to plot\n",
    "plot_categories(images_dict)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "rPt9q5bDhCGz"
   },
   "source": [
    "Now you know how the items in each one of the categories looks like.\n",
    "\n",
    "You might wonder what the empty image at the left is for. You will see shortly why it is important."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9BjJPAnhhcYS"
   },
   "source": [
    "## DeepExplainer\n",
    "\n",
    "To compute shap values for the model you just trained you will use the `DeepExplainer` class from the `shap` library.\n",
    "\n",
    "To instantiate this class you need to pass in a model along with training examples. Notice that not all of the training examples are passed in but only a fraction of them.\n",
    "\n",
    "This is done because the computations done by the `DeepExplainer` object are very intensive on the RAM and you might run out of it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-fZZQnMWvjc6"
   },
   "outputs": [],
   "source": [
    "# Take a random sample of 5000 training images\n",
    "background = x_train[np.random.choice(x_train.shape[0], 5000, replace=False)]\n",
    "\n",
    "# Use DeepExplainer to explain predictions of the model. You can safely ignore the warning about the Tensorflow version in the output.\n",
    "e = shap.DeepExplainer(model, background)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "YYPuBiOwjEDh"
   },
   "source": [
    "Now you can use the `DeepExplainer` instance to compute Shap values for images on the test set.\n",
    "\n",
    "So you can properly visualize these values for each class, create an array that contains one element of each class from the test set:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5InnVMMzvCwl"
   },
   "outputs": [],
   "source": [
    "# Save an example of each class from the test set\n",
    "x_test_dict = dict()\n",
    "for i, l in enumerate(y_test):\n",
    "  if len(x_test_dict)==10:\n",
    "    break\n",
    "  if l not in x_test_dict.keys():\n",
    "    x_test_dict[l] = x_test[i]\n",
    "\n",
    "# Convert to list preserving order of classes\n",
    "x_test_each_class = [x_test_dict[i] for i in sorted(x_test_dict)]\n",
    "\n",
    "# Convert to tensor\n",
    "x_test_each_class = np.asarray(x_test_each_class)\n",
    "\n",
    "# Print shape of tensor\n",
    "print(f\"x_test_each_class tensor has shape: {x_test_each_class.shape}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QYNRZ3WkqLfE"
   },
   "source": [
    "Before computing the shap values, make sure that the model is able to correctly classify each one of the examples you just picked:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "3gbdJQKcqN6y"
   },
   "outputs": [],
   "source": [
    "# Compute predictions\n",
    "predictions = model.predict(x_test_each_class)\n",
    "\n",
    "# Apply argmax to get predicted class\n",
    "np.argmax(predictions, axis=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "BNShTEtEqwsQ"
   },
   "source": [
    "Since the test examples are ordered according to the class number and the predictions array is also ordered, the model was able to correctly classify each one of these images."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "OPXkTm_Ll4Ua"
   },
   "source": [
    "## Visualizing Shap Values\n",
    "\n",
    "Now that you have an example of each class, compute the Shap values for each example:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "2Cm_HjmFuQzF"
   },
   "outputs": [],
   "source": [
    "# Compute shap values using DeepExplainer instance. You can safely ignore the warning about the Keras backend method.\n",
    "shap_values = e.shap_values(x_test_each_class)"
   ]
  },
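  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition about what these numbers mean, it may help to recall the game-theoretic definition behind them. The following is only a sketch, using notation introduced here that does not appear elsewhere in this lab: $N$ is the set of features (the pixels of an image), $v(S)$ is the expected model output when only the features in the subset $S$ are known, and $\\phi_i$ is the Shapley value of feature $i$:\n",
    "\n",
    "$$\\phi_i = \\sum_{S \\subseteq N \\setminus \\{i\\}} \\frac{|S|!\\,(|N| - |S| - 1)!}{|N|!} \\left( v(S \\cup \\{i\\}) - v(S) \\right)$$\n",
    "\n",
    "In words, $\\phi_i$ is a weighted average of the change in output caused by adding feature $i$ to every possible subset of the other features. Evaluating this sum exactly would require an exponential number of subsets, which is infeasible for a 28x28 image, so `DeepExplainer` approximates it with the help of the background sample.\n",
    "\n",
    "The values are additive: for a given class, a baseline $\\phi_0$ (the average model output over the background sample) plus the sum of the values of all pixels approximately recovers the model output $f(x)$ for that image:\n",
    "\n",
    "$$f(x) \\approx \\phi_0 + \\sum_{i \\in N} \\phi_i$$"
   ]
  },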
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nogfmvwsmtbl"
   },
   "source": [
    "Now take a look at the computed shap values. To understand the next illustration have these points in mind:\n",
    "- Positive shap values are denoted by red color and they represent the pixels that contributed to classifying that image as that particular class.\n",
    "- Negative shap values are denoted by blue color and they represent the pixels that contributed to NOT classify that image as that particular class.\n",
    "- Each row contains each one of the test images you computed the shap values for.\n",
    "- Each column represents the ordered categories that the model could choose from. Notice that `shap.image_plot` just makes a copy of the classified image, but you can use the `plot_categories` function you created earlier to show an example of that class for reference."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0t1dWh7rv0ya"
   },
   "outputs": [],
   "source": [
    "# Plot reference column\n",
    "plot_categories(images_dict)\n",
    "\n",
    "# Print an empty line to separate the two plots\n",
    "print()\n",
    "\n",
    "# Plot shap values\n",
    "shap.image_plot(shap_values, -x_test_each_class)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fqblFhpN1ogy"
   },
   "source": [
    "Now take some time to understand what the plot is showing you. Since the model is able to correctly classify each one of these 10 images, it makes sense that the shapley values along the diagonal are the most prevalent. Specially positive values since that is the class the model (correctly) predicted.\n",
    "\n",
    "\n",
    "What else can you derive from this plot? Try focusing on one example. For instance focus on the **coat** which is the fifth class. Looks like the model also had \"reasons\" to classify it as **pullover** or a **shirt**. This can be concluded from the presence of positive shap values for these clases.\n",
    "\n",
    "Let's take a look at the tensor of predictions to double check if this was the case:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fGorLRbV2qz7"
   },
   "outputs": [],
   "source": [
    "# Save the probability of belonging to each class for the fifth element of the set\n",
    "coat_probs = predictions[4]\n",
    "\n",
    "# Order the probabilities in ascending order\n",
    "coat_args = np.argsort(coat_probs)\n",
    "\n",
    "# Reverse the list and get the top 3 probabilities\n",
    "top_coat_args = coat_args[::-1][:3]\n",
    "\n",
    "# Print (ordered) top 3 classes\n",
    "for i in list(top_coat_args):\n",
    "  print(class_names[i])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "m-HDSUT64AdG"
   },
   "source": [
    "Indeed the model selected these 3 classes as the most probable ones for the **coat** image. This makes sense since these objects are similar to each other.\n",
    "\n",
    "\n",
    "Now look at the **t-shirt** which is the first class. This object is very similar to the **pullover** but without the long sleeves. It is not a surprise that white pixels in the area where the long sleeves are present will yield high shap values for classifying as a **t-shirt**. In the same way, white pixels in this area will yield negative shap values for classifying as a **pullover** since the model will expect these pixels to be colored if the item was indeed a **pullover**.\n",
    "\n",
    "\n",
    "You can get a lot of insight repeating this process for all the classes. What other conclusions can you arrive at?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-k1-8uGotB4E"
   },
   "source": [
    "-----------------------------\n",
    "**Congratulations on finishing this ungraded lab!** Now you should have a clearer understanding of what Shapley values are, why they are useful and how to compute them using the `shap` library.\n",
    "\n",
    "Deep Learning models were considered black boxes for a very long time. There is a natural trade off between predicting power and explanaibility in Machine Learning but thanks to the rise of new techniques such as SHapley Additive exPlanations it is easier than never before to explain the outputs of Deep Learning models.\n",
    "\n",
    "\n",
    "**Keep it up!**"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "private_outputs": true,
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
